//===-- IntrinsicCleaner.cpp ----------------------------------------------===//
//
//                    The Symbolic Slicer Library
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/IntrinsicLowering.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "Project.h"

#include <memory>
#include <string>

using namespace llvm;

namespace {
    // this pass may add and delete module variables (via intrinsic lowering).
    class IntrinsicCleanerPass {
        static char ID;
        const llvm::DataLayout &DataLayout;
        llvm::IntrinsicLowering *IL;
        bool LowerIntrinsics;

        bool runOnBasicBlock(llvm::BasicBlock &b, llvm::Module &M);
    public:
        IntrinsicCleanerPass(const llvm::DataLayout &TD,
                             bool LI=true)
            : DataLayout(TD),
              IL(new llvm::IntrinsicLowering(TD)),
              LowerIntrinsics(LI) {}
        ~IntrinsicCleanerPass() { delete IL; }

        bool runOnModule(llvm::Module &M);
    };
}

/// Cleans the intrinsic calls in every basic block of @p M. Afterwards it
/// erases the declaration of llvm.trap (whose calls runOnBasicBlock lowers to
/// abort()), but only once that declaration has no remaining users.
///
/// @param M  the module to clean in place.
/// @return   true if any block reported a change or the declaration was erased.
bool IntrinsicCleanerPass::runOnModule(Module &M) {
    bool dirty = false;
    for (Function &f : M)
        for (BasicBlock &b : f)
            dirty |= runOnBasicBlock(b, M);

    // Erasing a Function that is still referenced leaves dangling uses (and
    // asserts in debug builds), so only drop the declaration when it is dead.
    Function *Declare = M.getFunction("llvm.trap");
    if (Declare && Declare->use_empty()) {
        Declare->eraseFromParent();
        dirty = true;
    }
    return dirty;
}

/// Rewrites the intrinsic calls found in basic block @p b.
///
/// Handled intrinsics are removed, replaced by constants, or expanded into
/// ordinary IR in place (see the per-case comments). Intrinsics without a
/// dedicated case are handed to IntrinsicLowering when LowerIntrinsics is set.
///
/// @param b  block to clean; instructions may be inserted and erased, and
///           lowering llvm.trap also drops everything after the trap.
/// @param M  enclosing module; provides the LLVMContext and receives the
///           abort() declaration used to lower llvm.trap.
/// @return   true if an intrinsic was rewritten. Some intrinsics that are
///           left untouched are reported as changes as well.
bool IntrinsicCleanerPass::runOnBasicBlock(BasicBlock &b, Module &M) {
    bool dirty = false;
    LLVMContext &ctx = M.getContext();

    unsigned WordSize = DataLayout.getPointerSizeInBits() / 8;
    for (BasicBlock::iterator i = b.begin(), ie = b.end(); i != ie;) {
        IntrinsicInst *ii = dyn_cast<IntrinsicInst>(&*i);
        // Advance now: most cases below erase `ii`, which would invalidate an
        // iterator that still points at it.
        ++i;
        if (!ii)
            continue;

        switch (ii->getIntrinsicID()) {
        case Intrinsic::vastart:
        case Intrinsic::vaend:
            break;

        // Lower vacopy so that object resolution etc is handled by normal
        // instructions. With 4-byte pointers a single pointer is copied; with
        // 8-byte pointers three consecutive i64 words are copied.
        //
        // FIXME: This is much more target dependent than just the word size,
        // however this works for x86-32 and x86-64.
        case Intrinsic::vacopy: { // (dst, src) -> *((i8**) dst) = *((i8**) src)
            Value *dst = ii->getArgOperand(0);
            Value *src = ii->getArgOperand(1);

            if (WordSize == 4) {
                Type *i8pp = PointerType::getUnqual(PointerType::getUnqual(Type::getInt8Ty(ctx)));
                Value *castedDst = CastInst::CreatePointerCast(dst, i8pp, "vacopy.cast.dst", ii);
                Value *castedSrc = CastInst::CreatePointerCast(src, i8pp, "vacopy.cast.src", ii);
                Value *load = new LoadInst(castedSrc, "vacopy.read", ii);
                new StoreInst(load, castedDst, false, ii);
            } else {
                assert(WordSize == 8 && "Invalid word size!");
                Type *i64p = PointerType::getUnqual(Type::getInt64Ty(ctx));
                Value *pDst = CastInst::CreatePointerCast(dst, i64p, "vacopy.cast.dst", ii);
                Value *pSrc = CastInst::CreatePointerCast(src, i64p, "vacopy.cast.src", ii);
                Value *off = ConstantInt::get(Type::getInt64Ty(ctx), 1);
                for (unsigned word = 0; word != 3; ++word) {
                    if (word != 0) {
                        pDst = GetElementPtrInst::Create(nullptr, pDst, off, std::string(), ii);
                        pSrc = GetElementPtrInst::Create(nullptr, pSrc, off, std::string(), ii);
                    }
                    Value *val = new LoadInst(pSrc, std::string(), ii);
                    new StoreInst(val, pDst, ii);
                }
            }
            ii->eraseFromParent();
            // The block was rewritten, so report it.
            dirty = true;
            break;
        }

        case Intrinsic::sadd_with_overflow:
        case Intrinsic::ssub_with_overflow:
        case Intrinsic::smul_with_overflow:
        case Intrinsic::uadd_with_overflow:
        case Intrinsic::usub_with_overflow:
        case Intrinsic::umul_with_overflow: {
            // Evaluate the operation in twice the operand width, where it
            // cannot wrap. Overflow is flagged when the wide result lies
            // outside the range of the original type.
            IRBuilder<> builder(ii->getParent(), i);

            Value *op1 = ii->getArgOperand(0);
            Value *op2 = ii->getArgOperand(1);

            Value *result_ext = nullptr;
            Value *overflow = nullptr;
            std::string meta_data;
            unsigned int bw = op1->getType()->getPrimitiveSizeInBits();
            IntegerType *wideTy = IntegerType::get(ctx, bw * 2);
            Intrinsic::ID id = ii->getIntrinsicID();

            if (id == Intrinsic::uadd_with_overflow ||
                id == Intrinsic::usub_with_overflow ||
                id == Intrinsic::umul_with_overflow) {
                Value *op1ext = builder.CreateZExt(op1, wideTy);
                Value *op2ext = builder.CreateZExt(op2, wideTy);
                Value *int_max_s =
                    ConstantInt::get(op1->getType(), APInt::getMaxValue(bw));
                Value *int_max = builder.CreateZExt(int_max_s, wideTy);

                if (id == Intrinsic::uadd_with_overflow) {
                    result_ext = builder.CreateAdd(op1ext, op2ext);
                    meta_data = "uadd";
                } else if (id == Intrinsic::usub_with_overflow) {
                    // A negative difference wraps to a value above int_max.
                    result_ext = builder.CreateSub(op1ext, op2ext);
                    meta_data = "usub";
                } else {
                    result_ext = builder.CreateMul(op1ext, op2ext);
                    meta_data = "umul";
                }
                overflow = builder.CreateICmpUGT(result_ext, int_max);
            } else {
                Value *op1ext = builder.CreateSExt(op1, wideTy);
                Value *op2ext = builder.CreateSExt(op2, wideTy);
                Value *int_max_s =
                    ConstantInt::get(op1->getType(), APInt::getSignedMaxValue(bw));
                Value *int_min_s =
                    ConstantInt::get(op1->getType(), APInt::getSignedMinValue(bw));
                Value *int_max = builder.CreateSExt(int_max_s, wideTy);
                Value *int_min = builder.CreateSExt(int_min_s, wideTy);

                if (id == Intrinsic::sadd_with_overflow) {
                    result_ext = builder.CreateAdd(op1ext, op2ext);
                    meta_data = "sadd";
                } else if (id == Intrinsic::ssub_with_overflow) {
                    result_ext = builder.CreateSub(op1ext, op2ext);
                    meta_data = "ssub";
                } else {
                    result_ext = builder.CreateMul(op1ext, op2ext);
                    meta_data = "smul";
                }
                overflow = builder.CreateOr(builder.CreateICmpSGT(result_ext, int_max),
                                            builder.CreateICmpSLT(result_ext, int_min));
            }

            // This trunc could be replaced by a more general trunc replacement
            // that allows to detect also undefined behavior in assignments or
            // overflow in operation with integers whose dimension is smaller than
            // int's dimension, e.g.
            //     uint8_t = uint8_t + uint8_t;
            // if one desires the wrapping should write
            //     uint8_t = (uint8_t + uint8_t) & 0xFF;
            // before this, must check if it has side effects on other operations
            Value *result = builder.CreateTrunc(result_ext, op1->getType());
            Value *resultStruct =
                builder.CreateInsertValue(UndefValue::get(ii->getType()), result, 0);
            resultStruct = builder.CreateInsertValue(resultStruct, overflow, 1);
            // With constant operands IRBuilder folds the whole sequence into a
            // Constant, which cannot carry metadata, so only tag instructions.
            if (Instruction *i_result = dyn_cast<Instruction>(resultStruct))
                i_result->setMetadata("kuboT_" + meta_data, MDNode::get(ctx, None));

            ii->replaceAllUsesWith(resultStruct);
            ii->eraseFromParent();
            dirty = true;
            break;
        }

        case Intrinsic::dbg_value:
        case Intrinsic::dbg_declare:
            // Remove these regardless of lower intrinsics flag. This can
            // be removed once IntrinsicLowering is fixed to not have bad
            // caches.
            ii->eraseFromParent();
            dirty = true;
            break;

        case Intrinsic::trap: {
            // Lower "llvm.trap" to a call of abort() followed by unreachable.
            // An explicit void() FunctionType avoids the variadic overload.
            FunctionCallee abortFn = M.getOrInsertFunction(
                "abort", FunctionType::get(Type::getVoidTy(ctx), false));
            // If "abort" already exists with a different type the callee may
            // not be a plain Function; only annotate a real Function.
            if (Function *F = dyn_cast<Function>(abortFn.getCallee())) {
                F->setDoesNotReturn();
                F->setDoesNotThrow();
            }

            CallInst::Create(abortFn, Twine(), ii);
            new UnreachableInst(ctx, ii);
            ii->eraseFromParent();

            // The new unreachable now terminates the block, so every
            // instruction still after it (at least the block's original
            // terminator) is dead. Leaving them would put a terminator in the
            // middle of the block, which is invalid IR.
            if (i != ie) {
                Instruction &oldTerm = b.back();
                // Successors lose this block as a predecessor; fix their PHIs
                // (once per CFG edge) while the edges still exist.
                if (oldTerm.isTerminator())
                    for (unsigned s = 0, se = oldTerm.getNumSuccessors(); s != se; ++s)
                        oldTerm.getSuccessor(s)->removePredecessor(&b);
                while (i != ie) {
                    Instruction *dead = &*i;
                    ++i;
                    if (!dead->use_empty())
                        dead->replaceAllUsesWith(UndefValue::get(dead->getType()));
                    dead->eraseFromParent();
                }
            }

            dirty = true;
            break;
        }
        case Intrinsic::objectsize: {
            // We don't know the size of an object in general so we replace
            // with 0 or -1 depending on the second argument to the intrinsic.
            assert(ii->getNumArgOperands() == 4 && "wrong number of arguments");
            Value *minArg = ii->getArgOperand(1);
            assert(minArg && "Failed to get second argument");
            ConstantInt *minArgAsInt = dyn_cast<ConstantInt>(minArg);
            assert(minArgAsInt && "Second arg is not a ConstantInt");
            assert(minArgAsInt->getBitWidth() == 1 && "Second argument is not an i1");
            Value *replacement = nullptr;
            IntegerType *intType = dyn_cast<IntegerType>(ii->getType());
            assert(intType && "intrinsic does not have integer return type");

            if (minArgAsInt->isZero()) {
                // min=false
                replacement = ConstantInt::get(intType, -1, /*isSigned=*/true);
            } else {
                // min=true
                replacement = ConstantInt::get(intType, 0, /*isSigned=*/false);
            }
            ii->replaceAllUsesWith(replacement);
            ii->eraseFromParent();
            dirty = true;
            break;
        }
        case Intrinsic::x86_sse2_max_pd:
        case Intrinsic::x86_sse2_max_sd:
        case Intrinsic::x86_sse_cmp_ps:
        case Intrinsic::x86_sse_rsqrt_ps:
        case Intrinsic::fabs: {
            // Left in place, but reported as a change.
            dirty = true;
            break;
        }
        case Intrinsic::x86_sse2_psrai_d:
        // case Intrinsic::x86_sse2_storeu_dq:
        case Intrinsic::x86_sse2_cvtpd2ps: {
            //TODO: support this intrinsics
            break;
        }
        case Intrinsic::is_constant: {
            // llvm.is.constant is overloaded on its operand type, so any
            // operand type is accepted. Only ConstantInt operands fold to
            // true; every other operand folds to false.
            assert(ii->getNumArgOperands() == 1 && "wrong number of arguments");
            Value *val = ii->getArgOperand(0);
            Value *replacement = isa<ConstantInt>(val) ? ConstantInt::getTrue(ctx)
                                                       : ConstantInt::getFalse(ctx);
            ii->replaceAllUsesWith(replacement);
            ii->eraseFromParent();
            dirty = true;
            break;
        }
        case Intrinsic::usub_sat: {
            // usub.sat(a, b) = a <u b ? 0 : a - b
            IRBuilder<> builder(ii->getParent(), i);

            Value *op1 = ii->getArgOperand(0);
            Value *op2 = ii->getArgOperand(1);
            assert(op1->getType()->getPrimitiveSizeInBits() ==
                       op2->getType()->getPrimitiveSizeInBits() &&
                   "ops for usub_sat intrinsics are not of the same bit length");

            Value *diff = builder.CreateSub(op1, op2);
            Value *shouldBeZero = builder.CreateICmpULT(op1, op2);
            // getNullValue produces the matching zero for scalar and vector types.
            Value *zeroValue = Constant::getNullValue(op1->getType());
            Value *result = builder.CreateSelect(shouldBeZero, zeroValue, diff);
            ii->replaceAllUsesWith(result);
            ii->eraseFromParent();
            dirty = true;
            break;
        }
        case Intrinsic::uadd_sat: {
            // uadd.sat(a, b) = (a + b wrapped) ? UMAX : a + b. An unsigned
            // sum wrapped exactly when the truncated sum is smaller than an
            // operand, so compare against op1. Comparing the wrapped sum
            // against UMAX would never see the overflow.
            IRBuilder<> builder(ii->getParent(), i);

            Value *op1 = ii->getArgOperand(0);
            Value *op2 = ii->getArgOperand(1);
            Type *ty1 = op1->getType();
            assert(ty1->getPrimitiveSizeInBits() ==
                       op2->getType()->getPrimitiveSizeInBits() &&
                   "ops for uadd_sat intrinsics are not of the same bit length");

            Value *sum = builder.CreateAdd(op1, op2);
            Value *wrapped = builder.CreateICmpULT(sum, op1);
            // All-ones is the unsigned maximum at every width (splatted for
            // vectors), unlike a 64-bit literal zero-extended to wider types.
            Value *maximumVal = Constant::getAllOnesValue(ty1);
            Value *result = builder.CreateSelect(wrapped, maximumVal, sum);
            ii->replaceAllUsesWith(result);
            ii->eraseFromParent();
            dirty = true;
            break;
        }
        case Intrinsic::read_register: {
            assert(ii->getNumArgOperands() == 1 && "wrong number of arguments");
            //errs() <<"Err: unhandled read register encountered"<<*ii<<'\n';
            dirty = true;
            break;
        }
        case Intrinsic::write_register: {
            assert(ii->getNumArgOperands() == 2 && "wrong number of arguments");
            //errs() <<"Err: unhandled write register encountered"<<*ii<<'\n';
            dirty = true;
            break;
        }

        default:
            if (LowerIntrinsics)
                IL->LowerIntrinsicCall(ii);
            dirty = true;
            break;
        }
    }

    return dirty;
}

#include "IntrinsicCleaner.h"

/// Public entry point: runs a temporary IntrinsicCleanerPass over @p m, built
/// from the data layout @p dl with its default settings (intrinsics without a
/// dedicated case are lowered). The pass's "changed" result is discarded.
void cleanIntrinsics(Module &m, const DataLayout& dl) {
    IntrinsicCleanerPass{dl}.runOnModule(m);
}
